#!/usr/bin/env python
# Copyright 2015 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Driver for client scripts.

How it Works
============

This script is part of the infrastructure used for resource leak tests.
At a high level, we're trying to verify that specific types of operations
don't leak memory.  For example:

* Creating 100 clients in a loop doesn't leak memory within reason
* Making 100 API calls doesn't leak memory
* Streaming uploads/downloads to/from S3 have O(1) memory usage.

In order to do this, these tests are written to utilize two separate
processes: a driver and a runner.  The driver sends commands to the
runner which performs the actual commands.  While the runner is
running commands the driver can record various attributes about
the runner process.  So far this is just memory usage.

There's a BaseClientDriverTest test that implements the driver part
of the tests.  These should read like normal functional/integration tests.

This script (cmd-runner) implements the runner.  It listens for commands
from the driver and runs the commands.

The driver and runner communicate via a basic text protocol.  Commands
are line separated with arguments separated by spaces.  Commands
are sent to the runner via stdin.

On success, the runner responds through stdout with "OK".

Here's an example::

    +--------+                                     +-------+
    | Driver |                                     |Runner |
    +--------+                                     +-------+

         create_client s3
         -------------------------------------------->
               OK
         <-------------------------------------------

         make_aws_request 100 ec2 describe_instances
         -------------------------------------------->
               OK
         <-------------------------------------------

         stream_s3_upload bucket key /tmp/foo.txt
         -------------------------------------------->
               OK
         <-------------------------------------------

         exit
         -------------------------------------------->
               OK
         <-------------------------------------------

"""
import botocore.session
import sys


class CommandRunner(object):
    """Runner half of the driver/runner pair.

    Commands are read one per line from stdin.  The first
    whitespace-separated token names the command and is dispatched to the
    matching ``_do_<command>`` method; the remaining tokens are handed to
    that method as a list of strings.  ``OK`` is written to stdout once a
    command has completed.

    Clients, waiters and paginators created by commands are stored on the
    instance so they stay referenced until the corresponding ``free_*``
    command is received.
    """
    DEFAULT_REGION = 'us-west-2'

    def __init__(self):
        self._session = botocore.session.get_session()
        # Created objects are kept here until a free_* command drops them.
        self._clients = []
        self._waiters = []
        self._paginators = []

    def run(self):
        # Process commands from stdin until one of:
        # * the "exit" command (acknowledged with OK before stopping)
        # * a blank line
        # * EOF (readline() returns '', which also splits to no parts)
        # An unknown command name raises AttributeError from getattr.
        while True:
            # split() with no separator already ignores surrounding
            # whitespace, including the trailing newline.
            parts = sys.stdin.readline().split()
            if not parts:
                break
            command = parts[0]
            if command == 'exit':
                self._write_ok()
                break
            getattr(self, '_do_%s' % command)(parts[1:])
            self._write_ok()

    def _write_ok(self):
        # Acknowledge a command.  The flush makes the response visible
        # to the reader of stdout right away instead of sitting in a buffer.
        sys.stdout.write('OK\n')
        sys.stdout.flush()

    def _do_create_client(self, args):
        # The args provided by the user map directly to the positional
        # args passed to create_client.  At the least you need
        # to provide the service name.
        self._clients.append(self._session.create_client(*args))

    def _do_create_multiple_clients(self, args):
        # Args:
        # * num_clients - Number of clients to create in a loop
        # * client_args - Variable number of args to pass to create_client
        num_clients = int(args[0])
        client_args = args[1:]
        for _ in range(num_clients):
            self._clients.append(self._session.create_client(*client_args))

    def _do_free_clients(self, args):
        # Drops the references to all of the stored clients.
        self._clients = []

    def _do_create_waiter(self, args):
        # Args:
        # [0] client name used in its instantiation
        # [1] waiter name used in its instantiation
        client = self._create_client(args[0])
        self._waiters.append(client.get_waiter(args[1]))

    def _do_create_multiple_waiters(self, args):
        # Args:
        # [0] num_waiters - Number of waiters to create in a loop
        # [1] client name used in its instantiation
        # [2] waiter name used in its instantiation
        # A single client is created and shared by all of the waiters.
        num_waiters = int(args[0])
        client = self._create_client(args[1])
        for _ in range(num_waiters):
            self._waiters.append(client.get_waiter(args[2]))

    def _do_free_waiters(self, args):
        # Drops the references to all of the stored waiters.
        self._waiters = []

    def _do_create_paginator(self, args):
        # Args:
        # [0] client name used in its instantiation
        # [1] paginator name used in its instantiation
        client = self._create_client(args[0])
        self._paginators.append(client.get_paginator(args[1]))

    def _do_create_multiple_paginators(self, args):
        # Args:
        # [0] num_paginators - Number of paginators to create in a loop
        # [1] client name used in its instantiation
        # [2] paginator name used in its instantiation
        # A single client is created and shared by all of the paginators.
        num_paginators = int(args[0])
        client = self._create_client(args[1])
        for _ in range(num_paginators):
            self._paginators.append(client.get_paginator(args[2]))

    def _do_free_paginators(self, args):
        # Drops the references to all of the stored paginators.
        self._paginators = []

    def _do_make_aws_request(self, args):
        # Create a client and make a number of AWS requests.
        # Args:
        # * num_requests   - The number of requests to make in a loop
        # * service_name   - The name of the AWS service
        # * operation_name - The name of the operation, snake_cased
        # * operation_args - Variable number of kwargs to pass to the API
        #                    call, in the format Kwarg1:Val1 Kwarg2:Val2
        num_requests = int(args[0])
        service_name = args[1]
        operation_name = args[2]
        # Only the first ':' separates the name from the value, so values
        # may themselves contain ':'.  Values are passed through as strings.
        kwargs = dict(pair.split(':', 1) for pair in args[3:])
        client = self._create_client(service_name)
        method = getattr(client, operation_name)
        for _ in range(num_requests):
            method(**kwargs)

    def _do_stream_s3_upload(self, args):
        # Stream an upload to S3 from a local file.
        # This does *not* create an S3 bucket.  You need to create this
        # before running this command.  You will also need to create
        # the local file to upload before calling this command.
        # Args:
        # * bucket   - The name of the S3 bucket
        # * key      - The name of the S3 key
        # * filename - The name of the local filename to upload
        bucket, key, local_filename = args
        client = self._create_client('s3')
        with open(local_filename, 'rb') as f:
            client.put_object(Bucket=bucket, Key=key, Body=f)

    def _do_stream_s3_download(self, args):
        # Stream a download from S3 to a local file.
        # Before calling this command you'll need to create the S3 bucket
        # as well as the S3 object.  Also, the directory where the
        # file will be downloaded must already exist.
        # Args:
        # * bucket   - The name of the S3 bucket
        # * key      - The name of the S3 key
        # * filename - The local filename where the object will be downloaded
        bucket, key, local_filename = args
        client = self._create_client('s3')
        response = client.get_object(Bucket=bucket, Key=key)
        body_stream = response['Body']
        try:
            with open(local_filename, 'wb') as f:
                # Copy in 64 KiB chunks until read() returns b''.
                for chunk in iter(lambda: body_stream.read(64 * 1024), b''):
                    f.write(chunk)
        finally:
            # Always close the response body, even if a read or write
            # fails part way through the transfer.
            body_stream.close()

    def _create_client(self, service_name):
        # Create a client using the provided service name.
        # It will also inject a region name of self.DEFAULT_REGION.
        return self._session.create_client(service_name,
                                           region_name=self.DEFAULT_REGION)


def run():
    """Create a CommandRunner and let it process commands until it stops."""
    runner = CommandRunner()
    runner.run()


# Only start the command loop when executed as a script, so that importing
# this module does not block reading from stdin.
if __name__ == '__main__':
    run()
